import cuid
import asyncio
import json
import os
from typing import Any, List, Callable, Optional, Tuple, Dict, Awaitable, Union
from traceloop.sdk.client.http import HTTPClient
from traceloop.sdk.datasets.datasets import Datasets
from traceloop.sdk.evaluator.evaluator import Evaluator, validate_task_output
from traceloop.sdk.experiment.model import (
    InitExperimentRequest,
    ExperimentInitResponse,
    CreateTaskRequest,
    CreateTaskResponse,
    EvaluatorSpec,
    TaskResponse,
    RunInGithubRequest,
    RunInGithubResponse,
    TaskResult,
    GithubContext,
)
from traceloop.sdk.evaluator.config import EvaluatorDetails
import httpx


class Experiment:
    """Main Experiment class for creating experiment contexts.

    Runs a user-supplied async ``task`` over the rows of a dataset version and
    scores the outputs with evaluators. Outside GitHub Actions everything runs
    locally (each row is recorded as an experiment task and evaluated through
    the API); inside a GitHub Actions ``pull_request`` workflow the task outputs
    are submitted in a single request and the backend runs the evaluators.
    """

    _datasets: Datasets
    _evaluator: Evaluator
    _http_client: HTTPClient

    # Maximum number of dataset rows processed concurrently.
    _MAX_CONCURRENCY = 50
    # Timeout (seconds) passed to the evaluator when waiting for a result.
    _EVALUATOR_TIMEOUT_SEC = 120

    def __init__(self, http_client: HTTPClient, async_http_client: httpx.AsyncClient, experiment_slug: str):
        """Create an experiment context.

        Args:
            http_client: Synchronous client used for experiment, task and dataset API calls
            async_http_client: Async client handed to the ``Evaluator``
            experiment_slug: Default slug used when ``run()`` is not given one
        """
        self._datasets = Datasets(http_client)
        self._evaluator = Evaluator(async_http_client)
        self._http_client = http_client
        self._experiment_slug = experiment_slug

    async def run(
        self,
        task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
        evaluators: List[EvaluatorSpec],
        dataset_slug: Optional[str] = None,
        dataset_version: Optional[str] = None,
        experiment_slug: Optional[str] = None,
        experiment_metadata: Optional[Dict[str, Any]] = None,
        related_ref: Optional[Dict[str, str]] = None,
        aux: Optional[Dict[str, str]] = None,
        stop_on_error: bool = False,
        wait_for_results: bool = True,
    ) -> Tuple[List[TaskResponse], List[str]] | RunInGithubResponse:
        """Run an experiment with the given task and evaluators

        When the ``GITHUB_ACTIONS`` environment variable is set, the task is
        executed here and its results are submitted to the backend, which runs
        the evaluators. Otherwise task and evaluators are run locally.

        Args:
            task: Async function to run on each dataset row
            evaluators: List of evaluator slugs or EvaluatorDetails objects to run
            dataset_slug: Slug of the dataset to use
            dataset_version: Version of the dataset to use
            experiment_slug: Slug for this experiment run
            experiment_metadata: Metadata for this experiment (an experiment holds all the experiment runs)
            related_ref: Related reference for this experiment run
            aux: Auxiliary information for this experiment run
            stop_on_error: Whether to stop on first error (default: False). Local runs only.
            wait_for_results: Whether to wait for async tasks to complete (default: True). Local runs only.

        Returns:
            Locally: tuple of (results, errors); ([], []) if wait_for_results is False.
            In GitHub Actions: the backend's RunInGithubResponse.
        """
        if os.getenv("GITHUB_ACTIONS"):
            return await self._run_in_github(
                task=task,
                evaluators=evaluators,
                dataset_slug=dataset_slug,
                dataset_version=dataset_version,
                experiment_slug=experiment_slug,
                # Forward the caller's metadata; it used to be dropped on this path.
                experiment_metadata=experiment_metadata,
                related_ref=related_ref,
                aux=aux,
            )
        return await self._run_locally(
            task=task,
            evaluators=evaluators,
            dataset_slug=dataset_slug,
            dataset_version=dataset_version,
            experiment_slug=experiment_slug,
            experiment_metadata=experiment_metadata,
            related_ref=related_ref,
            aux=aux,
            stop_on_error=stop_on_error,
            wait_for_results=wait_for_results,
        )

    async def _run_locally(
        self,
        task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
        evaluators: List[EvaluatorSpec],
        dataset_slug: Optional[str] = None,
        dataset_version: Optional[str] = None,
        experiment_slug: Optional[str] = None,
        experiment_metadata: Optional[Dict[str, Any]] = None,
        related_ref: Optional[Dict[str, str]] = None,
        aux: Optional[Dict[str, str]] = None,
        stop_on_error: bool = False,
        wait_for_results: bool = True,
    ) -> Tuple[List[TaskResponse], List[str]]:
        """Run the task over every dataset row and evaluate each output through the API.

        For each row: awaits ``task(row)``, validates the output against the
        ``EvaluatorDetails`` evaluators, records it as an experiment task, then
        runs (``wait_for_results=True``) or only triggers (``False``) every
        evaluator. Rows run concurrently, at most ``_MAX_CONCURRENCY`` at once.
        Rows are only fetched when both ``dataset_slug`` and ``dataset_version``
        are given.

        Args:
            task: Async function to run on each dataset row
            evaluators: List of evaluator slugs or EvaluatorDetails objects to run
            dataset_slug: Slug of the dataset to use
            dataset_version: Version of the dataset to use
            experiment_slug: Slug for this experiment run
            experiment_metadata: Metadata for this experiment (an experiment holds all the experiment runs)
            related_ref: Related reference for this experiment run
            aux: Auxiliary information for this experiment run
            stop_on_error: Whether to stop on first error (default: False). Rows still
                running when the first error is collected are cancelled.
            wait_for_results: Whether to wait for async tasks to complete (default: True)

        Returns:
            Tuple of (results, errors). Returns ([], []) if wait_for_results is False

        Raises:
            Exception: If the experiment cannot be initialized.
        """
        experiment_slug = self._resolve_experiment_slug(experiment_slug)
        experiment_run_metadata = self._build_run_metadata(related_ref, aux)

        # Normalize evaluators to (slug, version, config) tuples.
        evaluator_details: Optional[List[Tuple[str, Optional[str], Optional[Dict[str, Any]]]]] = None
        if evaluators:
            evaluator_details = []
            for evaluator in evaluators:
                if isinstance(evaluator, str):
                    evaluator_details.append((evaluator, None, None))
                elif isinstance(evaluator, EvaluatorDetails):
                    evaluator_details.append((evaluator.slug, evaluator.version, evaluator.config))

        experiment = self._init_experiment(
            experiment_slug,
            dataset_slug=dataset_slug,
            dataset_version=dataset_version,
            evaluator_slugs=[slug for slug, _, _ in evaluator_details]
            if evaluator_details
            else None,
            experiment_metadata=experiment_metadata,
            experiment_run_metadata=experiment_run_metadata,
        )
        experiment_id = experiment.experiment.id
        run_id = experiment.run.id

        rows = self._fetch_rows(dataset_slug, dataset_version)
        evaluators_to_validate = self._evaluators_to_validate(evaluators)

        async def run_single_row(row: Optional[Dict[str, Any]]) -> TaskResponse:
            try:
                task_result = await task(row)

                # Validate task output with EvaluatorDetails with required_input_fields from evaluators list
                if evaluators_to_validate:
                    validate_task_output(task_result, evaluators_to_validate)

                # NOTE(review): _create_task uses the synchronous HTTP client, so it
                # blocks the event loop while the request is in flight.
                task_id = self._create_task(
                    experiment_slug=experiment_slug,
                    experiment_run_id=run_id,
                    task_input=row,
                    task_output=task_result,
                ).id

                eval_results: Dict[str, Union[Dict[str, Any], str]] = {}
                for evaluator_slug, evaluator_version, evaluator_config in evaluator_details or []:
                    try:
                        if wait_for_results:
                            eval_result = await self._evaluator.run_experiment_evaluator(
                                evaluator_slug=evaluator_slug,
                                evaluator_version=evaluator_version,
                                evaluator_config=evaluator_config,
                                task_id=task_id,
                                experiment_id=experiment_id,
                                experiment_run_id=run_id,
                                input=task_result,
                                timeout_in_sec=self._EVALUATOR_TIMEOUT_SEC,
                            )
                            eval_results[evaluator_slug] = eval_result.result
                        else:
                            await self._evaluator.trigger_experiment_evaluator(
                                evaluator_slug=evaluator_slug,
                                evaluator_version=evaluator_version,
                                evaluator_config=evaluator_config,
                                task_id=task_id,
                                experiment_id=experiment_id,
                                experiment_run_id=run_id,
                                input=task_result,
                            )
                            eval_results[evaluator_slug] = f"Triggered execution of {evaluator_slug}"
                    except Exception as e:
                        # A failing evaluator is reported per evaluator; the row still succeeds.
                        eval_results[evaluator_slug] = f"Error: {str(e)}"
                        print(f"\033[91m❌ Evaluator '{evaluator_slug}' failed: {str(e)}\033[0m")

                return TaskResponse(
                    task_result=task_result,
                    evaluations=eval_results,
                )
            except Exception as e:
                print(f"\033[91m❌ Task execution failed: {str(e)}\033[0m")
                if stop_on_error:
                    raise
                return TaskResponse(error=f"Error processing row: {str(e)}")

        semaphore = asyncio.Semaphore(self._MAX_CONCURRENCY)

        async def run_with_semaphore(row: Optional[Dict[str, Any]]) -> TaskResponse:
            async with semaphore:
                return await run_single_row(row)

        tasks = [asyncio.create_task(run_with_semaphore(row)) for row in rows]

        if not wait_for_results:
            # Still need to execute tasks to trigger evaluators, but don't wait for completion
            await asyncio.gather(*tasks, return_exceptions=True)
            return [], []

        results: List[TaskResponse] = []
        errors: List[str] = []
        try:
            for completed_task in asyncio.as_completed(tasks):
                try:
                    result = await completed_task
                except Exception as e:
                    # Only reachable with stop_on_error: run_single_row re-raises then.
                    errors.append(f"Task execution error: {str(e)}")
                    if stop_on_error:
                        break
                    continue
                if result.error:
                    errors.append(result.error)
                else:
                    results.append(result)
        finally:
            # After an early break, rows still in flight must not keep running detached.
            await self._cancel_pending(tasks)

        return results, errors

    async def _run_in_github(
        self,
        task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
        evaluators: List[EvaluatorSpec],
        dataset_slug: Optional[str] = None,
        dataset_version: Optional[str] = None,
        experiment_slug: Optional[str] = None,
        experiment_metadata: Optional[Dict[str, Any]] = None,
        related_ref: Optional[Dict[str, str]] = None,
        aux: Optional[Dict[str, str]] = None,
    ) -> RunInGithubResponse:
        """Execute tasks locally and submit results to backend for GitHub CI/CD

        This method:
        1. Validates the GitHub Actions pull_request context
        2. Fetches the dataset
        3. Executes all tasks locally
        4. Sends task results to backend
        5. Backend runs evaluators and posts PR comment

        Args:
            task: Async function to run on each dataset row
            evaluators: List of evaluator slugs or EvaluatorDetails objects to run
            dataset_slug: Slug of the dataset to use
            dataset_version: Version of the dataset
            experiment_slug: Slug for this experiment run
            experiment_metadata: Metadata for this experiment (an experiment holds all the experiment runs);
                ``created_from="github"`` is added to it
            related_ref: Additional reference information for this experiment run
            aux: Auxiliary information for this experiment run

        Returns:
            RunInGithubResponse with experiment_id, run_id, and status

        Raises:
            RuntimeError: If not running in a GitHub Actions pull_request workflow
                with GITHUB_REPOSITORY and a pull-request GITHUB_REF set
            ValueError: If a task output fails validation against the evaluators
            Exception: If the API request fails
        """
        if not os.getenv("GITHUB_ACTIONS"):
            raise RuntimeError(
                "run_in_github() can only be used in GitHub Actions CI/CD environment. "
                "To run experiments locally, use the run() method instead."
            )

        # Validate the PR context before running any task, so a misconfigured
        # workflow fails fast instead of after executing the whole dataset.
        github_context = self._build_github_context()

        experiment_slug = self._resolve_experiment_slug(experiment_slug)
        rows = self._fetch_rows(dataset_slug, dataset_version)
        task_results = await self._execute_tasks(rows, task, evaluators)

        request_body = RunInGithubRequest(
            experiment_slug=experiment_slug,
            dataset_slug=dataset_slug,
            dataset_version=dataset_version,
            evaluator_slugs=self._evaluator_slugs(evaluators),
            task_results=task_results,
            github_context=github_context,
            experiment_metadata=dict(experiment_metadata or {}, created_from="github"),
            experiment_run_metadata=self._build_run_metadata(related_ref, aux),
        )

        response = self._http_client.post(
            "/experiments/run-in-github",
            request_body.model_dump(mode="json", exclude_none=True),
        )

        if response is None:
            raise Exception(
                f"Failed to submit experiment '{experiment_slug}' for GitHub execution. "
            )

        return RunInGithubResponse(**response)

    @staticmethod
    def _build_github_context() -> GithubContext:
        """Build the GithubContext from the GitHub Actions environment variables.

        Raises:
            RuntimeError: If the workflow event is not ``pull_request``, or
                GITHUB_REPOSITORY / a ``refs/pull/<n>/...`` GITHUB_REF is missing.
        """
        repository = os.getenv("GITHUB_REPOSITORY")
        server_url = os.getenv("GITHUB_SERVER_URL", "https://github.com")
        github_event_name = os.getenv("GITHUB_EVENT_NAME", "")

        if github_event_name != "pull_request":
            raise RuntimeError(
                f"run_in_github() can only be used in pull_request workflow. "
                f"Current event: {github_event_name}. "
                "To run experiments locally, use the run() method instead."
            )

        # Extract PR number from GITHUB_REF (format: "refs/pull/123/merge")
        github_ref = os.getenv("GITHUB_REF", "")
        pr_number = None
        if github_ref.startswith("refs/pull/"):
            pr_number = github_ref.split("/")[2]

        if not repository or not github_ref or not pr_number:
            raise RuntimeError(
                "GITHUB_REPOSITORY and GITHUB_REF must be set in the environment. "
                "To run experiments locally, use the run() method instead."
            )

        return GithubContext(
            repository=repository,
            pr_url=f"{server_url}/{repository}/pull/{pr_number}",
            commit_hash=os.getenv("GITHUB_SHA", ""),
            actor=os.getenv("GITHUB_ACTOR", ""),
        )

    def _resolve_experiment_slug(self, experiment_slug: Optional[str]) -> str:
        """Return the given slug, else the instance default, else a generated ``exp-<cuid>`` slug."""
        if experiment_slug:
            return experiment_slug
        return self._experiment_slug or "exp-" + str(cuid.cuid())[:11]

    @staticmethod
    def _build_run_metadata(
        related_ref: Optional[Dict[str, str]],
        aux: Optional[Dict[str, str]],
    ) -> Dict[str, Any]:
        """Collect the non-None run metadata entries under the ``related_ref`` / ``aux`` keys."""
        return {
            key: value
            for key, value in [("related_ref", related_ref), ("aux", aux)]
            if value is not None
        }

    @staticmethod
    def _evaluators_to_validate(evaluators: Optional[List[EvaluatorSpec]]) -> List[EvaluatorDetails]:
        """Return the EvaluatorDetails entries (the ones task outputs are validated against)."""
        return [evaluator for evaluator in evaluators or [] if isinstance(evaluator, EvaluatorDetails)]

    @staticmethod
    def _evaluator_slugs(evaluators: Optional[List[EvaluatorSpec]]) -> Optional[List[str]]:
        """Return the slug of each str/EvaluatorDetails evaluator, or None when none were given."""
        if not evaluators:
            return None
        slugs: List[str] = []
        for evaluator in evaluators:
            if isinstance(evaluator, str):
                slugs.append(evaluator)
            elif isinstance(evaluator, EvaluatorDetails):
                slugs.append(evaluator.slug)
        return slugs

    def _fetch_rows(
        self,
        dataset_slug: Optional[str],
        dataset_version: Optional[str],
    ) -> List[Dict[str, Any]]:
        """Download and parse a dataset version; return [] unless both slug and version are given."""
        if not (dataset_slug and dataset_version):
            return []
        jsonl_data = self._datasets.get_version_jsonl(dataset_slug, dataset_version)
        return self._parse_jsonl_to_rows(jsonl_data)

    @staticmethod
    async def _cancel_pending(tasks: List["asyncio.Task[Any]"]) -> None:
        """Cancel unfinished tasks and wait until every task has settled.

        Gathering with ``return_exceptions=True`` also retrieves exceptions of
        tasks that failed but were never awaited, so asyncio does not report
        them as "Task exception was never retrieved".
        """
        for pending in tasks:
            if not pending.done():
                pending.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    def _init_experiment(
        self,
        experiment_slug: str,
        dataset_slug: Optional[str] = None,
        dataset_version: Optional[str] = None,
        evaluator_slugs: Optional[List[str]] = None,
        experiment_metadata: Optional[Dict[str, Any]] = None,
        experiment_run_metadata: Optional[Dict[str, Any]] = None,
    ) -> ExperimentInitResponse:
        """Initialize (create or fetch) the experiment and a new run via ``PUT /experiments/initialize``.

        Raises:
            Exception: If the API returns no response.
        """
        body = InitExperimentRequest(
            slug=experiment_slug,
            dataset_slug=dataset_slug,
            dataset_version=dataset_version,
            evaluator_slugs=evaluator_slugs,
            experiment_metadata=experiment_metadata,
            experiment_run_metadata=experiment_run_metadata,
        )
        response = self._http_client.put(
            "/experiments/initialize", body.model_dump(mode="json")
        )
        if response is None:
            raise Exception(
                f"Failed to create or fetch experiment with slug '{experiment_slug}'"
            )
        return ExperimentInitResponse(**response)

    def _create_task(
        self,
        experiment_slug: str,
        experiment_run_id: str,
        task_input: Optional[Dict[str, Any]],
        task_output: Dict[str, Any],
    ) -> CreateTaskResponse:
        """Record one task's input/output under the given experiment run.

        Raises:
            Exception: If the API returns no response.
        """
        body = CreateTaskRequest(
            input=task_input,
            output=task_output,
        )
        response = self._http_client.post(
            f"/experiments/{experiment_slug}/runs/{experiment_run_id}/task",
            body.model_dump(mode="json"),
        )
        if response is None:
            raise Exception(f"Failed to create task for experiment '{experiment_slug}'")
        return CreateTaskResponse(**response)

    def _parse_jsonl_to_rows(self, jsonl_data: str) -> List[Dict[str, Any]]:
        """Parse JSONL string into list of {col_name: col_value} dictionaries.

        The first line (columns definition) is skipped; blank lines and lines
        that are not valid JSON are skipped as well (best effort).
        """
        if not jsonl_data:
            return []

        rows = []
        lines = jsonl_data.strip().split("\n")

        # Skip the first line (columns definition)
        for line in lines[1:]:
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                # Skip invalid JSON lines
                continue

        return rows

    async def _execute_tasks(
        self,
        rows: List[Dict[str, Any]],
        task: Callable[[Optional[Dict[str, Any]]], Awaitable[Dict[str, Any]]],
        evaluators: Optional[List[EvaluatorSpec]] = None,
    ) -> List[TaskResult]:
        """Execute tasks locally with concurrency control

        A task that raises is recorded as a TaskResult with ``error`` set. A
        task output that fails validation against the EvaluatorDetails
        evaluators aborts the whole run: the ValueError propagates and the
        remaining rows are cancelled.

        Args:
            rows: List of dataset rows to process
            task: Function to run on each row
            evaluators: List of evaluators to validate task output against

        Returns:
            List of TaskResult objects with inputs, outputs, and errors (in completion order)

        Raises:
            ValueError: If a task output fails validation
        """
        evaluators_to_validate = self._evaluators_to_validate(evaluators)

        async def run_single_row(row: Optional[Dict[str, Any]]) -> TaskResult:
            try:
                task_output = await task(row)
            except Exception as e:
                # Any task failure (including a ValueError raised by the task
                # itself) is a per-row error, not a reason to abort the run.
                print(f"\033[91m❌ Task execution error: {str(e)}\033[0m")
                return TaskResult(
                    input=row,
                    error=str(e),
                )

            # Validate every row's output; kept outside the try above so only
            # validation failures abort the run.
            if evaluators_to_validate:
                try:
                    validate_task_output(task_output, evaluators_to_validate)
                except ValueError as validation_error:
                    print(f"\033[91m❌ Task validation failed: {str(validation_error)}\033[0m")
                    raise

            return TaskResult(
                input=row,
                output=task_output,
            )

        # Execute tasks with concurrency control
        semaphore = asyncio.Semaphore(self._MAX_CONCURRENCY)

        async def run_with_semaphore(row: Optional[Dict[str, Any]]) -> TaskResult:
            async with semaphore:
                return await run_single_row(row)

        tasks = [asyncio.create_task(run_with_semaphore(row)) for row in rows]

        task_results: List[TaskResult] = []
        try:
            for completed_task in asyncio.as_completed(tasks):
                task_results.append(await completed_task)
        finally:
            # On a validation failure, stop the rows that are still running.
            await self._cancel_pending(tasks)

        return task_results
